# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ICNet implementation attempting to closely follow the original authors' model at:
# https://github.com/hszhao/ICNet
# Important differences:
# 1) Upsampling is nearest-neighbour instead of bilinear since it is impossible
#    to export bilinear upsampling to ONNX yet
# 2) Weight initialization is omitted because it caused mIoU degradation on CamVid

from collections import OrderedDict

import torch
import torch.nn.functional as F
from numpy import lcm
from packaging import version
from torch import nn

from examples.torch.common.example_logger import logger
from nncf.torch.utils import is_tracing_state


class ConvBN(nn.Module):
    """2D convolution immediately followed by batch normalization (no activation)."""

    def __init__(self, in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        """Build the convolution and its batch-norm layer.

        All arguments except ``bias`` are forwarded to ``nn.Conv2d`` as its
        geometry; ``bias`` toggles the convolution bias term.
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        # PyTorch momentum 0.05 corresponds to momentum 0.95 in Caffe notation
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.05)

    def forward(self, inputs):
        """Return ``bn(conv(inputs))``."""
        return self.bn(self.conv(inputs))


class ConvBNReLU(nn.Module):
    """``ConvBN`` followed by an in-place ReLU activation."""

    def __init__(self, in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        """Build the conv+BN pair and the activation.

        Every argument is passed through unchanged to ``ConvBN``.
        """
        super().__init__()
        self.convbn = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        """Return ``relu(convbn(inputs))``."""
        return self.relu(self.convbn(inputs))


class ResNetBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 increase, plus a shortcut.

    The shortcut is the identity unless the main path changes the tensor shape,
    in which case a 1x1 ``ConvBN`` projection is applied to the input so that
    both paths can be summed.
    """

    def __init__(self, in_channels, reduce_channels, increase_channels, dilation=1, stride=1):
        """Build the block.

        :param in_channels: number of channels of the block input.
        :param reduce_channels: number of channels inside the bottleneck.
        :param increase_channels: number of channels of the block output.
        :param dilation: dilation of the 3x3 convolution.
        :param stride: stride of the first 1x1 convolution (and of the shortcut
            projection, when one is present).
        """
        super().__init__()
        # For a 3x3 kernel, padding == dilation keeps the spatial size unchanged.
        nonshrinking_padding = dilation
        self.conv_1x1_reduce_bnrelu = ConvBNReLU(
            in_channels, out_channels=reduce_channels, kernel_size=1, stride=stride, padding=0, dilation=1, bias=False
        )
        self.conv_3x3_bnrelu = ConvBNReLU(
            in_channels=reduce_channels,
            out_channels=reduce_channels,
            kernel_size=3,
            stride=1,
            padding=nonshrinking_padding,
            dilation=dilation,
            bias=False,
        )
        self.conv_1x1_increase_bn = ConvBN(
            in_channels=reduce_channels,
            out_channels=increase_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            bias=False,
        )
        # The shortcut must be projected whenever the main path changes the
        # tensor shape: either the channel count differs, or the strided 1x1
        # reduce convolution shrinks the spatial size. Checking only the channel
        # count would make `fx + x` fail for strided blocks with equal channels.
        self.need_proj = in_channels != increase_channels or stride != 1
        if self.need_proj:
            self.conv_1x1_proj_bn = ConvBN(
                in_channels, out_channels=increase_channels, kernel_size=1, stride=stride, padding=0, bias=False
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        """Return ``relu(main_path(inputs) + shortcut(inputs))``."""
        fx = self.conv_1x1_reduce_bnrelu(inputs)
        fx = self.conv_3x3_bnrelu(fx)
        fx = self.conv_1x1_increase_bn(fx)
        shortcut = self.conv_1x1_proj_bn(inputs) if self.need_proj else inputs
        return self.relu(fx + shortcut)


class ICNetBackbone(nn.Module):
    """Container for the convolutional stages shared by the ICNet branches.

    The stages are exposed as attributes (``conv1``, ``maxpool``, ``conv2``,
    ``conv3_1``, ``conv3_rest``, ``conv4``, ``conv5``) and are invoked
    individually by ``ICNet``; this module has no forward computation of its own.
    """

    def __init__(self, in_channels):
        """Build all backbone stages.

        :param in_channels: number of channels of the image fed to ``conv1``.
        """
        super().__init__()
        # Naming conventions below are chosen to correspond to the
        # icnet_cityscapes_bnnomerge.prototxt file in the original ICNet Github
        # repository. Although ICNet low-resolution branch layers 'conv3', 'conv4' and 'conv5',
        # are based upon ResNet50, they rather correspond to ResNet50 layers
        # 'conv2', 'conv3' and 'conv4' respectively.

        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv1_1_3x3_s2",
                        ConvBNReLU(
                            in_channels, out_channels=32, kernel_size=3, stride=2, padding=1, dilation=1, bias=False
                        ),
                    ),
                    (
                        "conv1_2_3x3",
                        ConvBNReLU(
                            in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, bias=False
                        ),
                    ),
                    (
                        "conv1_3_3x3",
                        ConvBNReLU(
                            in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1, bias=False
                        ),
                    ),
                ]
            )
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Sequential(
            OrderedDict(
                [
                    ("conv2_1", ResNetBlock(64, 32, 128)),
                    ("conv2_2", ResNetBlock(128, 32, 128)),
                    ("conv2_3", ResNetBlock(128, 32, 128)),
                ]
            )
        )
        # 'conv3' is split in two: the first (strided) block is kept separate
        # from the remaining blocks so that they can be called independently.
        self.conv3_1 = ResNetBlock(128, 64, 256, stride=2)
        self.conv3_rest = nn.Sequential(
            OrderedDict(
                [
                    ("conv3_2", ResNetBlock(256, 64, 256)),
                    ("conv3_3", ResNetBlock(256, 64, 256)),
                    ("conv3_4", ResNetBlock(256, 64, 256)),
                ]
            )
        )
        # NOTE: this stage used to be constructed twice in a row; the first copy
        # was immediately overwritten, so it is built only once here.
        self.conv4 = nn.Sequential(
            OrderedDict(
                [
                    ("conv4_1", ResNetBlock(256, 128, 512, dilation=2)),
                    ("conv4_2", ResNetBlock(512, 128, 512, dilation=2)),
                    ("conv4_3", ResNetBlock(512, 128, 512, dilation=2)),
                    ("conv4_4", ResNetBlock(512, 128, 512, dilation=2)),
                    ("conv4_5", ResNetBlock(512, 128, 512, dilation=2)),
                    ("conv4_6", ResNetBlock(512, 128, 512, dilation=2)),
                ]
            )
        )
        self.conv5 = nn.Sequential(
            OrderedDict(
                [
                    ("conv5_1", ResNetBlock(512, 256, 1024, dilation=4)),
                    ("conv5_2", ResNetBlock(1024, 256, 1024, dilation=4)),
                    ("conv5_3", ResNetBlock(1024, 256, 1024, dilation=4)),
                ]
            )
        )

    def forward(self):
        """Intentionally a no-op: the stages are called one by one from ``ICNet``."""
        pass


def get_backbone(backbone, in_channels):
    """Instantiate the backbone selected by name.

    :param backbone: backbone identifier; only ``"icnet"`` is recognised.
    :param in_channels: number of input image channels passed to the backbone.
    :return: a new ``ICNetBackbone`` instance.
    :raises NotImplementedError: for any other backbone identifier.
    """
    if backbone != "icnet":
        raise NotImplementedError
    return ICNetBackbone(in_channels)


class PyramidPooling(nn.Module):
    """Pyramid pooling: pool the input at several bin sizes and merge the results.

    For every bin dimension the input is adaptively average-pooled to a
    ``dim x dim`` map, upsampled (nearest-neighbour) to ``input_size_hw`` and
    then merged with the running result, either by element-wise addition
    (``mode="sum"``) or by concatenation along dimension 1 (``mode="cat"``).
    """

    def __init__(self, input_size_hw, bin_dimensions=None, mode="sum"):
        """Configure the pooling pyramid.

        :param input_size_hw: (height, width) that pooled maps are upsampled to;
            it has to match the spatial size of the tensors passed to ``forward``
            for the merge to be shape-compatible.
        :param bin_dimensions: output sizes of the adaptive pooling levels;
            defaults to ``[1, 2, 3, 6]``.
        :param mode: ``"sum"`` or ``"cat"``.
        :raises NotImplementedError: if ``mode`` is neither ``"sum"`` nor ``"cat"``.
        """
        super().__init__()

        if mode not in ("sum", "cat"):
            raise NotImplementedError

        self.mode = mode
        self.input_size_hw = input_size_hw
        # self.sampling_params = {'mode': 'bilinear', 'align_corners': True}
        self.sampling_params = {"mode": "nearest"}
        if bin_dimensions is None:
            self.bin_dimensions = [1, 2, 3, 6]
        else:
            self.bin_dimensions = bin_dimensions

        # ONNX only supports exporting adaptive_avg_pool2d if the input tensor
        # height and width are exact multiples of the output_size (i.e. bin dimensions).
        # Inference-time pad calculation is also impossible to export to ONNX, therefore
        # the required padding parameters are pre-calculated here, at init.
        # NOTE: these paddings are only stored; forward() does not apply them.
        self.paddings = {}
        for dim in self.bin_dimensions:
            pad_h = (dim - (input_size_hw[0] % dim)) % dim
            pad_w = (dim - (input_size_hw[1] % dim)) % dim
            self.paddings[dim] = (0, pad_w, 0, pad_h)

    def forward(self, inputs):
        """Return ``inputs`` merged with its pooled-and-upsampled pyramid levels.

        In ``"sum"`` mode the output has the same shape as ``inputs``; in
        ``"cat"`` mode the size of dimension 1 grows by one copy of the input
        channels per bin dimension.
        """
        # Work on a copy so that the in-place accumulation never touches `inputs`.
        x = inputs.clone()

        for dim in self.bin_dimensions:
            # TODO(vshampor): adaptive_avg_pool2d is sensitive to input shapes being known at export-time
            # via shape inference. FakeQuantize ops that we export do not currently provide shape inference
            # info and it is impossible to do so in 1.13 due to a bug in PyTorch.
            # The quantized ICNet models will most likely not be exportable to ONNX until both of the following
            # is done: 1) torch >= 1.14 is used with NNCF, and 2) shape inference is added to FakeQuantize ONNX ops
            # in the corresponding `symbolic` functions.
            # See also:
            # https://discuss.pytorch.org/t/adding-shape-inference-to-custom-operator-for-onnx-exporting/163452
            # https://github.com/pytorch/pytorch/issues/81693

            pooled_feature = F.adaptive_avg_pool2d(inputs, dim)
            pooled_feature = F.interpolate(pooled_feature, self.input_size_hw, **self.sampling_params)
            if self.mode == "sum":
                x += pooled_feature
            elif self.mode == "cat":
                # torch.cat expects a sequence of tensors; append this pyramid
                # level to the accumulated features along dimension 1.
                x = torch.cat([x, pooled_feature], dim=1)
            else:
                raise NotImplementedError

        return x


class CascadeFeatureFusion(nn.Module):
    """Fuse a low-resolution feature map with a higher-resolution one.

    The low-resolution input is upsampled (nearest-neighbour) to
    ``highres_size_hw``, passed through a dilated 3x3 ``ConvBN`` and added to
    the 1x1-projected high-resolution input; the sum goes through a ReLU.
    In training mode an auxiliary classifier is additionally applied to the
    upsampled low-resolution features.
    """

    def __init__(self, in_channels_lowres, in_channels_highres, highres_size_hw, num_classes):
        """Build the fusion layers.

        :param in_channels_lowres: channels of the low-resolution input.
        :param in_channels_highres: channels of the high-resolution input.
        :param highres_size_hw: (height, width) the low-resolution input is upsampled to.
        :param num_classes: number of output channels of the auxiliary classifier.
        """
        super().__init__()
        self.highres_size_hw = highres_size_hw
        # self.sampling_params = {'mode': 'bilinear', 'align_corners': True}
        self.sampling_params = {"mode": "nearest"}

        self.conv = ConvBN(
            in_channels_lowres,
            out_channels=128,
            kernel_size=3,
            padding=2,
            dilation=2,
            bias=False,
        )
        self.conv_proj = ConvBN(
            in_channels_highres,
            out_channels=128,
            kernel_size=1,
            padding=0,
            dilation=1,
            bias=False,
        )
        self.relu = nn.ReLU(inplace=True)
        self.classifier = nn.Conv2d(
            in_channels_lowres,
            out_channels=num_classes,
            kernel_size=1,
            padding=0,
            dilation=1,
            bias=True,
        )

    def forward(self, lowres_input, highres_input):
        """Return the fused features, plus auxiliary class scores when training.

        :return: the fused tensor in eval mode; a ``(fused, aux_labels)`` tuple
            in training mode.
        """
        lowres_upsampled = F.interpolate(lowres_input, self.highres_size_hw, **self.sampling_params)
        fused = self.relu(self.conv(lowres_upsampled) + self.conv_proj(highres_input))
        if not self.training:
            return fused
        return fused, self.classifier(lowres_upsampled)


class ICNet(nn.Module):
    """ICNet segmentation model with high-, medium- and low-resolution branches.

    In training mode ``forward`` returns an ``OrderedDict`` of class scores at
    1/4, 1/8 and 1/16 of the input resolution (keys ``"ds4"``, ``"ds8"``,
    ``"ds16"``). In eval mode it returns class scores upsampled to the full
    input resolution, or their softmax over dimension 1 while the model is
    being traced for export.
    """

    def __init__(self, input_size_hw, in_channels=3, n_classes=20, backbone="icnet"):
        """Build the network for a fixed input resolution.

        :param input_size_hw: (height, width) of the images the model will process.
        :param in_channels: number of input image channels.
        :param n_classes: number of segmentation classes.
        :param backbone: backbone identifier handed to ``get_backbone``.
        :raises ValueError: if the height or width is not a multiple of the
            least common multiple of 32 and all pyramid pooling bin dimensions.
        """
        super().__init__()
        self._input_size_hw = input_size_hw

        # Pre-compute the spatial size of every downsampled stage.
        full_h = self._input_size_hw[0]
        full_w = self._input_size_hw[1]
        self._input_size_hw_ds2 = (full_h // 2, full_w // 2)
        self._input_size_hw_ds4 = (full_h // 4, full_w // 4)
        self._input_size_hw_ds8 = (full_h // 8, full_w // 8)
        self._input_size_hw_ds16 = (full_h // 16, full_w // 16)
        self._input_size_hw_ds32 = (full_h // 32, full_w // 32)

        # self.sampling_params = {'mode': 'bilinear', 'align_corners': True}
        self.sampling_params = {"mode": "nearest"}

        self.backbone = get_backbone(backbone, in_channels)

        # Three stride-2 convolutions: the high-resolution branch output is at 1/8 resolution.
        highres_layers = [
            ("conv1_sub1", ConvBNReLU(in_channels, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False)),
            ("conv2_sub1", ConvBNReLU(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False)),
            ("conv3_sub1", ConvBNReLU(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False)),
        ]
        self.highres_conv = nn.Sequential(OrderedDict(highres_layers))

        # 'conv5_4_k1' is applied immediately after pyramid pooling and before
        # cascade feature fusion
        self.conv5_4_k1 = ConvBNReLU(
            in_channels=1024,
            out_channels=256,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            bias=False,
        )

        # Using pyramid pooling in 'sum' mode instead of 'cat' as in PSPNet,
        # probably because in ICNet it is immediately followed by 1x1 reduce
        # convolution anyway
        self.ppm = PyramidPooling(self._input_size_hw_ds32)
        self.cff42 = CascadeFeatureFusion(
            in_channels_lowres=256,
            in_channels_highres=256,
            highres_size_hw=self._input_size_hw_ds16,
            num_classes=n_classes,
        )
        self.cff421 = CascadeFeatureFusion(
            in_channels_lowres=128,
            in_channels_highres=32,
            highres_size_hw=self._input_size_hw_ds8,
            num_classes=n_classes,
        )
        self.conv6_cls = nn.Conv2d(128, out_channels=n_classes, kernel_size=1, padding=0, dilation=1, bias=True)

        # The input must be divisible by 32 and by every pyramid bin dimension.
        required_alignment = 32
        for bin_dim in self.ppm.bin_dimensions:
            required_alignment = lcm(required_alignment, bin_dim)
        if any(input_size_hw[axis] % required_alignment for axis in (0, 1)):
            raise ValueError("ICNet may only operate on {}-aligned input resolutions".format(required_alignment))
        # No explicit weight initialization is performed here: the layers keep
        # the PyTorch default initialization.

    def highres_branch(self, inputs):
        """Run the full-resolution image through the shallow high-resolution convolutions."""
        return self.highres_conv(inputs)

    def mediumres_branch(self, inputs):
        """Run the backbone stages up to and including ``conv3_1``."""
        x = inputs
        for stage in (self.backbone.conv1, self.backbone.maxpool, self.backbone.conv2, self.backbone.conv3_1):
            x = stage(x)
        return x

    def lowres_branch(self, inputs):
        """Run the remaining backbone stages, pyramid pooling and the 1x1 reduce convolution."""
        x = inputs
        for stage in (self.backbone.conv3_rest, self.backbone.conv4, self.backbone.conv5, self.ppm, self.conv5_4_k1):
            x = stage(x)
        return x

    def forward(self, inputs):
        """Compute class scores for ``inputs`` (see the class docstring for the return value)."""
        features_sub1 = self.highres_branch(inputs)

        half_res_input = F.interpolate(inputs, self._input_size_hw_ds2, **self.sampling_params)
        features_sub2 = self.mediumres_branch(half_res_input)

        # Contrary to the ICNet paper Fig.2 , the low-resolution branch does not receive separate
        # 4x-downsampled image input, but instead reuses feature maps from the medium-resolution
        # branch.
        lowres_input = F.interpolate(features_sub2, self._input_size_hw_ds32, **self.sampling_params)
        features_sub4 = self.lowres_branch(lowres_input)

        # In training mode each fusion unit also yields auxiliary class scores.
        if self.training:
            fused_sub42, label_scores_ds16 = self.cff42(features_sub4, features_sub2)
            fused_sub421, label_scores_ds8 = self.cff421(fused_sub42, features_sub1)
        else:
            fused_sub42 = self.cff42(features_sub4, features_sub2)
            fused_sub421 = self.cff421(fused_sub42, features_sub1)

        fused_ds4 = F.interpolate(fused_sub421, self._input_size_hw_ds4, **self.sampling_params)
        label_scores_ds4 = self.conv6_cls(fused_ds4)

        if self.training:
            return OrderedDict([("ds4", label_scores_ds4), ("ds8", label_scores_ds8), ("ds16", label_scores_ds16)])

        label_scores = F.interpolate(label_scores_ds4, self._input_size_hw, **self.sampling_params)
        exporting = is_tracing_state() and version.parse(torch.__version__) >= version.parse("1.1.0")
        if not exporting:
            return label_scores
        # While exporting, add extra post-processing layers into the graph
        # so that the model outputs class probabilities instead of class scores
        return F.softmax(label_scores, dim=1)


def icnet(num_classes, pretrained=False, **kwargs):
    """Create an ``ICNet`` model.

    :param num_classes: number of segmentation classes (passed as ``n_classes``).
    :param pretrained: if truthy, a warning is logged because no pretrained
        weights are available; the model is returned with its initial weights.
    :param kwargs: remaining keyword arguments forwarded to ``ICNet``.
    :return: the constructed ``ICNet`` instance.
    """
    net = ICNet(n_classes=num_classes, **kwargs)
    if not pretrained:
        return net

    logger.warning("ICNet has no pretrained weights")
    return net
